# -*- coding: utf-8 -*-
"""SAM nonconvex.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/10zRiQ-B5HaXzn4I8aAThVuBfjnnWKFTM
"""

import numpy as np
import matplotlib.pyplot as plt
def learning_rate(A):
    """Return ``1 / lambda_max(A.T @ A)``, i.e. ``1 / ||A||_2**2``.

    The largest eigenvalue of the Gram matrix ``A.T @ A`` equals the squared
    spectral norm of ``A``; its reciprocal is a common step-size scale for
    gradient methods on ``Ax - b`` residual objectives.

    Parameters
    ----------
    A : numpy.ndarray, 2-D
        Model matrix. It must not be all zeros; otherwise the largest
        eigenvalue is 0 and the division below is a division by zero.

    Returns
    -------
    numpy.float64
        Reciprocal of the largest eigenvalue of ``A.T @ A``.
    """
    gram = A.transpose() @ A
    # ``A.T @ A`` is symmetric positive semidefinite, so use the symmetric
    # eigensolver. It guarantees real eigenvalues returned in ascending
    # order, instead of a general complex spectrum whose round-off imaginary
    # parts had to be thrown away.
    max_eigenval = np.linalg.eigvalsh(gram)[-1]
    return 1 / max_eigenval

def objective_function(A, b, x):
    """
    Objective function: ∑ log(1 + (Ax - b)^2)

    Computes f(x) = sum_i log(1 + r_i^2) with residual r = A @ x - b.
    This is a nonconvex, robust (Cauchy-type) loss: it behaves like r^2 for
    small residuals and grows only logarithmically for large ones.

    Parameters
    ----------
    A : numpy.ndarray
        Model matrix.
    b : numpy.ndarray
        Target vector.
    x : numpy.ndarray
        Point at which to evaluate the objective.

    Returns
    -------
    numpy.float64
        The scalar objective value (0 exactly when A @ x == b).
    """
    residual = np.dot(A, x) - b
    # log1p(t) keeps full relative accuracy for tiny t. In contrast,
    # log(1 + t) rounds 1 + t to 1 and returns exactly 0 once t is below
    # machine epsilon. That matters near the minimizer, where these values
    # are plotted on a log scale.
    return np.sum(np.log1p(residual**2))

def gradient(A, b, x):
    """
    Gradient of the objective function

    For f(x) = sum_i log(1 + r_i^2) with r = A @ x - b, the chain rule gives

        grad f(x) = 2 * A^T (r / (1 + r^2)),

    with the division and square taken elementwise.

    Returns
    -------
    numpy.ndarray
        Gradient vector with the same length as ``x``.
    """
    residual = np.dot(A, x) - b
    # d/dr log(1 + r^2) = 2r / (1 + r^2); the factor 2 is applied after
    # mapping back through A^T.
    per_row_slope = residual / (1 + residual**2)
    return 2 * np.dot(A.T, per_row_slope)

def gradient_descent(A, b, x, lr, num_iterations):
    """
    Gradient Descent algorithm

    Performs x_{k+1} = x_k - (lr / (k + 1)) * grad f(x_k) for
    k = 0, ..., num_iterations - 1, i.e. a step size decaying like 1/(k+1).
    The caller's ``x`` is not modified; each step builds a new array.

    Returns
    -------
    list
        num_iterations + 1 objective values: f(x_0) followed by f after
        each step. The final iterate itself is not returned.
    """
    history = [objective_function(A, b, x)]
    iterate = x
    for step in range(1, num_iterations + 1):
        iterate = iterate - (lr / step) * gradient(A, b, iterate)
        history.append(objective_function(A, b, iterate))
    return history
def SAM(A, b, x, lr, rho, num_iterations, p):
    """
    Sharpness-Aware Minimization with diminishing step size and radius.

    For k = 0, ..., num_iterations - 1:

        g_k     = grad f(x_k)
        eps_k   = g_k / ||g_k||           (unit ascent direction)
        r_k     = rho / (k + 1)**p         (perturbation radius)
        x_{k+1} = x_k - (lr / (k + 1)) * grad f(x_k + r_k * eps_k)

    With p = 0 the radius stays constant at rho; larger p shrinks it faster.
    If g_k is exactly zero, the direction is undefined (0/0), so no
    perturbation is applied for that step.

    Parameters
    ----------
    A, b : numpy.ndarray
        Problem data for f(x) = sum_i log(1 + (A @ x - b)_i^2).
    x : numpy.ndarray
        Starting point (not modified; each step builds a new array).
    lr : float
        Base step size, divided by (k + 1) at step k.
    rho : float
        Base perturbation radius.
    num_iterations : int
        Number of SAM steps.
    p : float
        Decay exponent for the radius.

    Returns
    -------
    list
        num_iterations + 1 objective values: f(x_0), then f after each step.
    """
    funall = [objective_function(A, b, x)]
    for k in range(num_iterations):
        # Evaluate the gradient once and reuse it for the normalization.
        grad = gradient(A, b, x)
        grad_norm = np.linalg.norm(grad)
        # At a stationary point grad / ||grad|| is nan, and nan would
        # propagate into every later iterate; use a zero perturbation instead.
        eps = grad / grad_norm if grad_norm > 0 else np.zeros_like(grad)
        r = rho / (k + 1) ** p
        x = x - (lr / (k + 1)) * gradient(A, b, x + r * eps)
        funall.append(objective_function(A, b, x))
    return funall

# Example usage: for several problem sizes n, compare diminishing-step GD
# with SAM using a constant radius (p=0) and radii rho/(k+1)**p.
#
# Panel 0: GD, constant-radius SAM, and every decaying-radius SAM run.
# Panel 1: constant-radius SAM alone.
# Panel 2: the SAM run for the last p in p_values alone.

# Raw string so ``\sum`` / ``\log`` reach mathtext unchanged (a plain string
# makes ``\s`` an invalid escape sequence).
_TITLE = r'Performance on $f(x)=\sum_{i=1}^n\log(1+(Ax-b)_i^2)$'


def _style_axis(ax):
    """Apply the shared legend, axis labels, title and log y-scale to a panel."""
    ax.legend()
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Function Value')
    ax.set_title(_TITLE)
    ax.set_yscale('log')


p_values = [1, 0.1, 0.001]
n_values = [2, 20, 50, 100]
for n in n_values:
    np.random.seed(42)  # same seed for every size -> reproducible figures
    A = np.random.rand(n, n)
    b = np.random.rand(n)
    # Fixed base step 0.1/n (decayed as lr/(k+1) inside the solvers); the
    # spectral estimate learning_rate(A) is not used for these runs.
    lr = 0.1 / n
    # Start slightly away from the least-squares solution of Ax = b.
    # lstsq avoids forming A.T @ A, which would square the condition number.
    x_ls = np.linalg.lstsq(A, b, rcond=None)[0]
    x = x_ls + (0.1 / n**2) * np.ones(n)
    maxnum = 100 * n

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))  # three panels in a row
    funGD = gradient_descent(A, b, x, lr, maxnum)
    funallSAMcon = SAM(A, b, x, lr, rho=0.1, num_iterations=maxnum, p=0)
    axs[0].plot(range(len(funGD)), funGD, label='GD')
    axs[0].plot(range(len(funallSAMcon)), funallSAMcon, marker='x', color='m', label='SAM-constant')
    for p in p_values:
        funallSAM = SAM(A, b, x, lr, rho=0.1, num_iterations=maxnum, p=p)
        label = 'SAM-$p={}$'.format(p)
        if p == p_values[-1]:
            # The last p is drawn in red and also shown on its own in panel 2.
            axs[2].plot(range(len(funallSAM)), funallSAM, linewidth=1.5, color='r', label=label)
            axs[0].plot(range(len(funallSAM)), funallSAM, linewidth=1.5, color='r', label=label)
        else:
            axs[0].plot(range(len(funallSAM)), funallSAM, linewidth=1.5, label=label)
    axs[1].plot(range(len(funallSAMcon)), funallSAMcon, linewidth=1.5, color='m', label='SAM-constant')

    for ax in axs:
        _style_axis(ax)
    plt.tight_layout()  # adjust subplot parameters to give specified padding
    filename = "SAMnonconvex_n_{}.png".format(n)
    plt.savefig(filename)
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow across sizes